# coding:utf-8

"""
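main -- entry script: loads the 32x32 train/test sets, reforms and normalizes them, then builds and runs the Network model on them.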
Created on 2016/12/8 15:23
@author: GuoYufu
@group : OceanHorn
@contact: OceanHorn@163.com
"""

import Image_loader
import Image_util
import Network

def main():
    """Load the 32x32 train/test data, preprocess it, and run the network.

    Loading is delegated to ``Image_loader``, reformatting and normalization
    to ``Image_util``, and model construction/execution to ``Network``.
    Keeping this in a function (instead of inline under the ``__main__``
    guard) keeps the large intermediate arrays out of module globals.
    """
    # Raw samples/labels for each named split, as returned by the loader.
    test_samples_source, test_labels_source = \
        Image_loader.Image_loader.load_samples_and_labels("test_32x32")
    train_samples_source, train_labels_source = \
        Image_loader.Image_loader.load_samples_and_labels("train_32x32")
    # NOTE: the "extra_32x32" split is currently not loaded. If it is needed,
    # load it with Image_loader.Image_loader.load_samples_and_labels(
    # "extra_32x32") -- the previous commented-out call used the unqualified
    # Image_loader.load_samples_and_labels, which does not match the
    # class-qualified API used above.

    test_samples, test_labels = Image_util.Image_util.reform(
        test_samples_source, test_labels_source)
    train_samples, train_labels = Image_util.Image_util.reform(
        train_samples_source, train_labels_source)

    train_samples = Image_util.Image_util.normalize(train_samples)
    test_samples = Image_util.Image_util.normalize(test_samples)

    net = Network.Network(
        num_hidden=128,
        batch_size=1000,
        test_batch_size=500,
        image_size=32,
        num_labels=10,
        num_channels=1,
    )
    net.define_graph()
    net.run(
        train_samples=train_samples,
        train_labels=train_labels,
        test_samples=test_samples,
        test_labels=test_labels,
    )


if __name__ == "__main__":
    main()